[MLIR][XeGPU] Matrix load/store subgroup distribution #165008
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes: This PR enables sg-to-wi distribution of xegpu matrix load/store ops.

Full diff: https://github.com/llvm/llvm-project/pull/165008.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+ constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+ int operandIdx{-1};
+
+ VectorType payloadTy;
+ VectorType warpResultTy;
+ if constexpr (isLoad) {
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ operandIdx = producedByLastLoad->getOperandNumber();
+ payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+ warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ } else {
+ payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ }
+ if (!payloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp,
+ "The matrix op payload has no layouts, using defaults instead.");
+
+ SmallVector<Value> operands;
+ if constexpr (isLoad)
+ operands = {matrixOp.getMemDesc()};
+ else
+ operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ if constexpr (!isLoad)
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+ newOperands[operandIdxToModify] = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+ newWarpOp.getLaneid());
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ if constexpr (isLoad) {
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], newOffsets, newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ } else {
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+ xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ }
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- GpuBarrierDistribution, VectorMultiReductionDistribution,
- LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
- MemrefExtractAlignedPointerAsIndexDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/regularPatternBenefit);
+ patterns
+ .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+ DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+ VectorMultiReductionDistribution, LoadDistribution,
+ StoreDistribution, VectorTransposeDistribution,
+ VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+ MatrixOpDistribution<xegpu::StoreMatrixOp>,
+ MemrefExtractAlignedPointerAsIndexDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/regularPatternBenefit);
patterns.add<VectorShapeCastDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return
+ }
+}
This PR is currently WIP until there is more clarity on the offset adjustment. Is the logic going to be the same as wg-to-sg (i.e., delinearization based on lane_layout + lane_data)?

UPD: offset calculation clarified.
charithaintc
left a comment
Looks good. Will approve after the feedback.
```cpp
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
  valOrResVecTy = VectorType::get(1, data.getType());
if (valOrResVecTy.getShape().size() != 1)
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a verification here for load/store_matrix with the @subgroup_block_io attribute: the payload must be contiguous in memory.

Both of these IRs in the tests added in this PR are actually not correct, since the payload data are not contiguous between lanes. They are correct if you change the vector<2x16xf32> to vector<16x2xf32> (lane_layout/lane_data need to change accordingly, but that is out of the IR verifier's scope).

```mlir
%1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
  !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
  vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
```
Added verification. However, regarding

> change the vector<2x16xf32> to <16x2xf32>

I understand the logical reasoning for this in the matrix ops case, but the current distribution does not allow it, considering the "correct" lane layout the block load requires.
We have

```cpp
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
  if (i < distributionStart)
    continue;
  // Check if the dimension can be distributed evenly.
  if (dim % effectiveLaneLayout[i - distributionStart] != 0)
    return failure();
  distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
}
```

Meaning that given lane_layout = [1, 16], lane_data = [1, 1] and a 16x2 data shape, we get

```
shape[0] % layout[0] = 16 % 1 = 0  // good
shape[1] % layout[1] = 2 % 16 = 2  // fail
```
We can change the layout to be [16, 1], which would allow the pattern to complete and the distributed code to still be correct, since the lane layout is not used in further coordinate calculations. But [16, 1] may be harder for users to reason about by simply looking at the xevm block load description and the sg-level subgroup_block_io matrix op.
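To make the constraint concrete, here is a minimal standalone C++ sketch of that even-divisibility rule. The function name `distributeShape` and the simplification that the whole shape is distributed (i.e., distributionStart == 0) are assumptions for illustration; this is not the pass's actual helper. It reproduces the two cases above: a 16x2 payload cannot be distributed with lane_layout = [1, 16], but yields a 1x2 slice per lane with [16, 1].

```cpp
// Hypothetical sketch of the even-divisibility rule quoted above,
// not the pass's getDistVecTypeBasedOnLaneLayout.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

static std::optional<std::vector<int64_t>>
distributeShape(const std::vector<int64_t> &shape,
                const std::vector<int64_t> &laneLayout) {
  std::vector<int64_t> distributed(shape);
  for (size_t i = 0; i < shape.size(); ++i) {
    // A dimension is distributable only if the lanes divide it evenly.
    if (shape[i] % laneLayout[i] != 0)
      return std::nullopt;
    distributed[i] = shape[i] / laneLayout[i];
  }
  return distributed;
}

int main() {
  // 16x2 payload, lane_layout = [1, 16]: 2 % 16 != 0, so distribution fails.
  std::cout << distributeShape({16, 2}, {1, 16}).has_value() << "\n"; // 0
  // 16x2 payload, lane_layout = [16, 1]: each lane gets a 1x2 slice.
  auto perLane = distributeShape({16, 2}, {16, 1});
  std::cout << (*perLane)[0] << "x" << (*perLane)[1] << "\n"; // 1x2
}
```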
If the user uses stride = [1, 32] in the memory layout, then the user should be able to reason about sg_layout = [16, 1].
If the user uses lane_layout = [1, 16], they should not use a strided memory layout; the example above should just use a block layout. The matrix op with subgroup_block_io is a subgroup operation, and all lanes collectively access a contiguous memory buffer.
There is now a test with a 16x2xf32 result using the proper stride. Short snippet:

```mlir
%1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
  !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
```

It distributes to a 1x2xf32 per lane.
```cpp
genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
           ArrayRef<int64_t> sizePerSg,
           ArrayRef<int64_t> sizePerWg) {
```
I think the function should be genCoords(). Inside the function, when we compute each individual value, it is fine to still use "offset". A coordinate (coord) in an n-d tensor is a vector of logical offsets.
Actually, reading the code below, almost all "offsets" variables (vectors of Values) can be renamed to "coord".
Renamed.
```cpp
auto layout = matrixOp.getLayoutAttr();
if (!layout)
  return rewriter.notifyMatchFailure(
```
Can you add one verifier check here? For operations without subgroup_block_io, the lane_data must be physically contiguous in the SLM.

```mlir
// This is correct.
%1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
  !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, index, index -> vector<2x16xf32>

// This is not correct.
%1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
  !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x16xf32>
```

For operations with subgroup_block_io, the lane_data must be [1, 1].
See https://github.com/llvm/llvm-project/pull/158118/files#diff-151653efd45c964d5c1af1ceb208f2eb55a344a985e874b922495bf3e4256c81R213
Added a verifier check for both linear and coalesced data.
charithaintc
left a comment
LGTM.
Jianhui-Li
left a comment
some more comments about the verifier.
```cpp
                            [](int x) { return x == 1; });
if (!isLaneDataLinear)
  return emitError()
         << "With subgroup_block_io, lane data must be linear.";
```
I think we are checking that the tensor tile accessed by the matrix op (with subgroup_block_io) is contiguous. It means we need to check the following:

- lane_data = [1, ..., 1] // 1 for each dim
- for each dim x where lane_layout[x] != 1, stride[x] must be 1, and block[x] must equal lane_layout[x] if the layout is blocked.

Why call it "linear"? Can the error message just be "With subgroup_block_io, tensor tile accessed by subgroup must be contiguous"?
Addressed. I use literal strides from the attribute, not the getStrideShape.
Branch force-pushed from 6344048 to 10448e1, and then to c99294a.
Jianhui-Li
left a comment
LGTM